import numpy as np
import prettytable as pt

# Algorithm names offered by the first menu of modinit().  The position of a
# name in this list is the index the user types to select it, so the order
# must stay stable.
algorithms = [
    'ML', 'SD',
    'MF', 'ZF', 'MMSE',
    'ZF_SIC', 'MMSE_SIC', 'MF_SIC',
    'LR_ZF', 'LR_MMSE',
    'MMSE_LAS',
]

# Modulation presets, keyed by modulation name.  Every preset carries:
#   MOD_NAME       the preset's own key
#   CONSTELLATION  real-valued levels (float32)
#   M              number of complex constellation points
#   Mr             number of entries in CONSTELLATION
# Only PSK_8 lists its complex points explicitly; for the other presets
# constellation_complex_gen() derives CONSTELLATION_COMPLEX from CONSTELLATION.
Mod = {
    'QAM_4': {
        'MOD_NAME': 'QAM_4',
        'CONSTELLATION': np.array([-1, 1], dtype=np.float32),
        'M': 4,
        'Mr': 2,
    },
    'PSK_8': {
        'MOD_NAME': 'PSK_8',
        'CONSTELLATION': np.array(
            [-1, -0.7071, 0, 0.7071, 1],
            dtype=np.float32,
        ),
        'CONSTELLATION_COMPLEX': np.array(
            [
                0. - 1.j, -0.7071 - 0.7071j, 0.7071 - 0.7071j, -1. + 0.j,
                1. + 0.j, -0.7071 + 0.7071j, 0.7071 + 0.7071j, 0. + 1.j,
            ],
            dtype=np.complex64,
        ),
        'M': 8,
        'Mr': 5,
    },
    'QAM_16': {
        'MOD_NAME': 'QAM_16',
        'CONSTELLATION': np.array([-3, -1, 1, 3], dtype=np.float32),
        'M': 16,
        'Mr': 4,
    },
    'QAM_64': {
        'MOD_NAME': 'QAM_64',
        'CONSTELLATION': np.array(
            [-7, -5, -3, -1, 1, 3, 5, 7],
            dtype=np.float32,
        ),
        'M': 64,
        'Mr': 8,
    },
}

# Antenna presets; each key has the form '<Nt>x<Nr>' and maps to those two
# counts.  NOTE(review): Nt / Nr presumably mean transmit / receive antennas
# -- confirm against the code that consumes them.
antenna = {
    '5x10': {'Nt': 5, 'Nr': 10},
    '15x25': {'Nt': 15, 'Nr': 25},
    '20x25': {'Nt': 20, 'Nr': 25},
    '30x60': {'Nt': 30, 'Nr': 60},
}

# SNR presets, keyed by a '<min>-<max>' label.  Each preset holds one min/max
# pair and a second, wider "TRAIN" min/max pair.
# NOTE(review): values are presumably in dB, as the key names suggest.
SNR = {
    '8-13': {
        'SNR_dB_MIN': 8,
        'SNR_dB_MAX': 13,
        'SNR_dB_TRAIN_MIN': 7,
        'SNR_dB_TRAIN_MAX': 14,
    },
}

# Settings merged into every parameter dict by modinit(), whatever the menu
# choices were.
common = {
    'BATCH_SIZE': 5000,
    'MIN_ERROR_SYMBOLS': 100000,
}

def constellation_complex_gen(trainparams):
    """Add ``CONSTELLATION_COMPLEX`` to *trainparams* unless it is ``'PSK_8'``.

    For every ``MOD_NAME`` other than ``'PSK_8'`` the complex constellation is
    built as all sums ``re + 1j*im`` with ``re`` and ``im`` taken from
    ``trainparams['CONSTELLATION']``, then flattened to one dimension with the
    real part varying fastest.  The result is stored in place under
    ``'CONSTELLATION_COMPLEX'``, overwriting any existing entry.

    ``'PSK_8'`` parameter sets are left untouched.  Returns None.
    """
    if trainparams['MOD_NAME'] == 'PSK_8':
        return

    levels = np.reshape(trainparams['CONSTELLATION'], [1, -1])  # [1, Mr]
    grid = levels + 1j * np.transpose(levels)                   # [Mr, Mr]
    trainparams['CONSTELLATION_COMPLEX'] = np.reshape(grid, [-1])  # [M]

def choose(dicname, tb, params):
    """Show a menu for a module-level preset dict and merge the chosen preset.

    The keys of the dict named *dicname* are printed as a one-row table whose
    column headers are their indices.  The user types an index and the preset
    stored under the corresponding key is merged into *params*.

    Args:
        dicname: Name of a module-level dict of presets (this file passes
            ``'Mod'``, ``'antenna'`` and ``'SNR'``).  It is also used as the
            label of the menu row.
        tb: Table object used to render the menu (a
            ``prettytable.PrettyTable`` in this file).  It is cleared and
            refilled here.
        params: Dict updated in place with the selected preset's entries.

    Raises:
        KeyError: If this module has no global named *dicname*.
        ValueError: If the typed choice is not an integer.
        IndexError: If the typed index does not address an option.
    """
    # dicname is only ever a plain name, so a namespace lookup is enough and
    # avoids evaluating an arbitrary expression.
    dic = globals()[dicname]
    options = list(dic)

    tb.clear()
    tb.field_names = ['index'] + list(range(len(options)))
    tb.add_row([dicname] + options)
    print(tb)

    # The reply is parsed with int() instead of eval() so that text typed at
    # the prompt is never executed as Python code.
    selected = options[int(input('Tell me your choice: '))]
    params.update(dic[selected])

def updateL(params):
    """Store the fixed value 30 under ``params['L']`` in place; returns None."""
    params.update({'L': 30})

def _print_params(tb, params):
    """Print *params* as a one-row table without the CONSTELLATION_COMPLEX column."""
    tb.clear()
    tb.field_names = list(params.keys())
    tb.add_row(list(params.values()))
    # The complex constellation is shown in its own table at the end of
    # modinit(), so it is dropped from this summary row.
    tb.del_column('CONSTELLATION_COMPLEX')
    print(tb)


def modinit():
    """Interactively build and return the parameter dict.

    Steps, in order:

    1. Pick an algorithm from ``algorithms`` and store it under
       ``'algorithms'``; ``updateL`` then adds ``'L'``.
    2. Pick one preset each from ``Mod``, ``antenna`` and ``SNR`` via
       ``choose`` and merge them in, followed by the ``common`` settings.
    3. Fill in ``'CONSTELLATION_COMPLEX'`` via ``constellation_complex_gen``.
    4. Repeatedly show the parameters and let the user overwrite single
       entries by index until the reply is empty or ``'n'``.
    5. Print the final settings and the complex constellation.

    Returns:
        dict: The assembled parameters.

    Raises:
        ValueError: If the algorithm choice is not an integer.
        IndexError: If the algorithm choice is out of range.
    """
    print('Init')
    params = {}
    tb = pt.PrettyTable()

    # algorithm
    tb.field_names = ['index'] + list(range(len(algorithms)))
    tb.add_row(['algorithms'] + algorithms)
    print(tb)
    # int() rather than eval(): the reply is only ever used as a list index.
    params['algorithms'] = algorithms[int(input('Tell me your choice: '))]
    updateL(params)

    choose('Mod', tb, params)
    choose('antenna', tb, params)
    choose('SNR', tb, params)
    params.update(common)
    constellation_complex_gen(params)

    # fix
    while True:
        _print_params(tb, params)

        fix = input('Is there anything else need to be fixed?(n/index)')
        if fix in ('', 'n'):
            break

        # NOTE(review): the index counts every key of params, including the
        # hidden CONSTELLATION_COMPLEX.  For PSK_8 that key sits in the middle,
        # so positions counted from the printed table are off by one after it.
        try:
            # Resolve the key first so a bad index fails before the user is
            # asked to type a value.
            key = list(params)[int(fix)]
            # NOTE(review): eval() executes whatever is typed here.  It is
            # kept because replacement values may be arbitrary expressions;
            # only use this prompt with trusted console input.
            params[key] = eval(input('Tell me your num: '))
        except Exception:
            # Exception rather than a bare except, so Ctrl-C and SystemExit
            # still terminate the program.  eval() can raise anything, hence
            # no narrower class.
            print('out of range! fixing failed')

    # final
    print('\nFinally Settings:')
    _print_params(tb, params)

    tb.clear()
    tb.field_names = ['CONSTELLATION_COMPLEX']
    tb.add_row([params['CONSTELLATION_COMPLEX']])
    print(tb)

    return params

# Run the interactive configuration when this file is executed as a script;
# the returned parameter dict is discarded here.
if __name__ == '__main__':
    modinit()